import os
import json
import math
import numpy as np
import xml.etree.ElementTree as ET

# 计算平均精度
class DetectionMAP(object):
    def __init__(self, num_classes, iou_threshold=0.5):
        """
        功能: 
            初始化计算平均精度方法
        输入: 
            num_classes   - 预测类别数量
            iou_threshold - 测试交并比值
        输出:
        """
        self.num_classes = num_classes                     # 预测类别数量
        self.iou_threshold = iou_threshold                 # 测试交并比值
        self.count = [0] * self.num_classes                # 数量统计列表
        self.score = [[] for _ in range(self.num_classes)] # 得分统计列表
        
    def update(self, infer, gtbox, gtcls):
        """
        功能: 
            统计各类数量和得分
        输入: 
            infer - 预测结果
            gtbox - 物体边框
            gtcls - 物体类别
        输出:
        """
        # 统计各类数量
        for gtcls_item in gtcls:
            self.count[int(np.array(gtcls_item))] += 1
        
        # 统计各类得分
        visited = [False] * len(gtcls) # 各类访问标识
        for infer_item in infer:
            # 获取预测数据
            pdcls, pdsco, xmin, ymin, xmax, ymax = infer_item.tolist() # 获取预测数据
            pdbox = [xmin, ymin, xmax, ymax]                           # 获取预测边框
            
            # 计算最大边框
            max_index = -1 # 最大交并索引
            max_iou = -1.0 # 最大交并比值
            for i, gtcls_item in enumerate(gtcls): # 遍历真实类别列表
                if int(gtcls_item) == int(pdcls): # 如果真实类别等于预测类别，则计算交并比值
                    iou = self.get_box_iou_xyxy(pdbox, gtbox[i])
                    if iou > max_iou: # 如果交并比值大于最大交并比值，则更新最大交并比值和索引
                        max_index = i
                        max_iou = iou
            
            # 统计各类得分
            if max_iou > self.iou_threshold: # 如果最大交并比值大于测试交并比值
                if not visited[max_index]: # 如果该物体没有被统计，则添加到列表，并设置访问标识
                    self.score[int(pdcls)].append([pdsco, 1.0]) # 添加各类正确正例
                    visited[max_index] = True                   # 设置访问标识为真
                else: # 如果该物体已经被统计，则添加到列表，并设置为成错误正例
                    self.score[int(pdcls)].append([pdsco, 0.0]) # 添加各类错误正例
            else: # 如果最大交并比值不大于测试交并比值，则添加到列表，并设置成错误正例
                self.score[int(pdcls)].append([pdsco, 0.0])     # 添加各类错误正例
        
    def get_box_iou_xyxy(self, box1, box2):
        """
        功能: 
            计算边框交并比值
        输入: 
            box1 - 边界框1
            box2 - 边界框2
        输出:
            iou  - 交并比值
        """
        # 计算交集面积
        x1_min, y1_min, x1_max, y1_max = box1[0], box1[1], box1[2], box1[3]
        x2_min, y2_min, x2_max, y2_max = box2[0], box2[1], box2[2], box2[3]

        x_min = np.maximum(x1_min, x2_min)
        y_min = np.maximum(y1_min, y2_min)
        x_max = np.minimum(x1_max, x2_max)
        y_max = np.minimum(y1_max, y2_max)

        w = np.maximum(x_max - x_min + 1.0, 0)
        h = np.maximum(y_max - y_min + 1.0, 0)

        intersection = w * h # 交集面积

        # 计算并集面积
        s1 = (y1_max - y1_min + 1.0) * (x1_max - x1_min + 1.0)
        s2 = (y2_max - y2_min + 1.0) * (x2_max - x2_min + 1.0)

        union = s1 + s2 - intersection # 并集面积

        # 计算交并比
        iou = intersection / union

        return iou
    
    def get_mAP(self):
        """
        功能:
            计算各类平均精度
        输入:
        输出:
            mAP - 各类平均精度
        """
        # 计算每类精度
        mAP = 0 # 各类平均精度
        cnt = 0 # 各类类别计数
        for score, count in zip(self.score, self.count): # 遍历每类物体
            # 统计正误正例
            if count == 0 or len(score) == 0: # 如果该类数量为0，或得分列表为空，则继续下一个类别
                continue
            tp_list, fp_list = self.get_tp_fp_list(score) # 统计正误正例
            
            # 计算预测的准确率和召回率
            precision = [] # 准确率列表
            recall = []    # 召回率列表
            for tp, fp in zip(tp_list, fp_list):
                precision.append(float(tp) / (tp + fp)) # 添加准确率
                recall.append(float(tp) / count)        # 添加召回率
            
            # 计算平均精度
            AP = 0.0         # 平均精度
            pre_recall = 0.0 # 前召回率
            for i in range(len(precision)): # 遍历正确率列表
                recall_gap = math.fabs(recall[i] - pre_recall) # 计算召回率差值
                if recall_gap > 1e-6: # 如果召回率改变，则计算平均精度，更新前召回率
                    AP += precision[i] * recall_gap # 累加平均精度
                    pre_recall = recall[i]          # 更新前召回率
            
            # 更新各类精度
            mAP += AP # 累加各类精度
            cnt += 1  # 增加类别计数
            
        # 计算平均精度
        mAP = (mAP / float(cnt)) if cnt > 0 else mAP
        
        return mAP

    def get_tp_fp_list(self, score):
        """
        功能:
            对得分列表进行从大到小排序，按排序统计正确正例和错误正例数量
        输入:
            score   - 得分列表
        输出:
            tp_list - 正确正例列表
            fp_list - 错误正例列表
        """
        tp = 0       # 正确正例数量
        fp = 0       # 错误正例数量
        tp_list = [] # 正确正例列表
        fp_list = [] # 错误正例列表
        
        score_list = sorted(score, key=lambda s: s[0], reverse=True) # 对得分列表按从大到小排序
        for (score, label) in score_list:
            tp += int(label)     # 统计正确正例
            tp_list.append(tp)   # 添加正确正例
            fp += 1 - int(label) # 统计错误正例
            fp_list.append(fp)   # 添加错误正例
        
        return tp_list, fp_list
    
##############################################################################################################

object_names = ['Boerner', 'Leconte', 'Linnaeus', 'acuminatus', 'armandi', 'coleoptera', 'linnaeus'] # 物体名称
def get_object_gtcls():
    """
    功能:
        将物体名称映射成物体类别
    输入:
    输出:
        object_gtcls - 物体类别
    """
    object_gtcls = {} # 物体类别字典
    for key, value in enumerate(object_names):
        object_gtcls[value] = key # 将物体名称映射成物体类别
    return object_gtcls

def test(json_path, xmls_path, num_classes, iou_threshold):
    """
    功能:
        测试模型平均精度
    输入:
        json_path     - 预测结果路径
        xmls_path     - 标签目录路径
        num_classes   - 预测类别数量
        iou_threshold - 测试交并比值
    输出:
    """
    # 声明计算方法
    mAP = DetectionMAP(num_classes, iou_threshold)
    
    # 统计预测得分
    json_list = json.load(open(json_path))               # 读取预测结果
    for json_item in json_list: # 遍历预测结果
        # 读取预测文件
        image_name = str(json_item[0])                   # 读取文件名称
        infer = np.array(json_item[1]).astype('float32') # 读取预测结果
        
        # 读取标签文件
        tree = ET.parse(os.path.join(xmls_path, image_name + '.xml')) # 解析文件
        image_w = float(tree.find('size').find('width').text)         # 图像宽度
        image_h = float(tree.find('size').find('height').text)        # 图像高度
        
        object_list = tree.findall('object')                     # 物体列表
        gtbox = np.zeros((len(object_list), 4), dtype='float32') # 物体边框
        gtcls = np.zeros((len(object_list),  ), dtype='int32')   # 物体类别
        
        for i, object_item in enumerate(object_list):
            # 读取物体边框
            x_min = float(object_item.find('bndbox').find('xmin').text) # 物体边框x1
            y_min = float(object_item.find('bndbox').find('ymin').text) # 物体边框y1
            x_max = float(object_item.find('bndbox').find('xmax').text) # 物体边框x2
            y_max = float(object_item.find('bndbox').find('ymax').text) # 物体边框y2
            
            x_min = max(0.0, x_min)
            y_min = max(0.0, y_min)
            x_max = min(x_max, image_w - 1.0)
            y_max = min(y_max, image_h - 1.0)
            
            gtbox[i] = [x_min, y_min, x_max, y_max] # 设置物体边框
            
            # 读取物体类别
            object_name = object_item.find('name').text # 读取物体名称
            gtcls[i] = get_object_gtcls()[object_name]  # 将物体名称映射成物体类别
        
        # 统计预测得分
        mAP.update(infer, gtbox, gtcls)
        
    # 计算平均精度
    mAP_value = mAP.get_mAP() * 100 # 计算平均精度
    print("Detection mAP({:.2f}) = {:.2f}%".format(iou_threshold, mAP_value))